import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Linear, MLP
from copy import deepcopy
from itertools import accumulate
import numpy as np

class GINE_layer(nn.Module):
    """One message-passing step between two node sets, ``op`` and ``m``.

    A forward pass aggregates features along four relations (op->op,
    m->op, op->m, m->m), transforms each aggregate with its own MLP and
    overwrites ``data['op_x']`` / ``data['m_x']`` with the sum of the two
    messages arriving at that node set.

    NOTE(review): "op" / "m" presumably stand for the operations and
    machines of a scheduling graph -- confirm against the caller.
    """

    def __init__(self, args, in_dim):
        """Build the four relation-specific MLPs.

        Args:
            args: Configuration object; ``args.hidden_dim`` is read here.
            in_dim: Feature width of both ``op_x`` and ``m_x`` on entry.
                The two cross-relation MLPs take ``in_dim + 1`` inputs
                because one extra column (the row sum of ``op_m_idx``) is
                appended to their aggregated features.
        """
        super().__init__()
        self.args = args
        self.hidden_dim = args.hidden_dim

        self.op__to__op = MLP([in_dim, args.hidden_dim, args.hidden_dim])
        self.op__to__m = MLP([in_dim + 1, args.hidden_dim, args.hidden_dim])
        self.m__to__op = MLP([in_dim + 1, args.hidden_dim, args.hidden_dim])
        self.m__to__m = MLP([in_dim, args.hidden_dim, args.hidden_dim])

    @staticmethod
    def _apply_mlp(mlp, x):
        """Run ``mlp`` on ``x``, padding a one-row input to two rows."""
        # PyG's MLP normalises with batch norm by default, and batch norm
        # in training mode rejects a batch holding a single row. Feeding
        # the row twice and keeping the first output sidesteps that.
        if x.size(0) == 1:
            return mlp(torch.cat((x, x), dim=0))[:1, :]
        return mlp(x)

    def forward(self, data):
        """Update ``op_x`` and ``m_x`` in ``data`` and return ``data``.

        Args:
            data: Mapping that provides the tensors ``op_x``, ``m_x``,
                ``op_op_idx``, ``op_m_idx`` and ``m_m_idx``. The matrix
                products below require ``op_op_idx`` and ``op_m_idx`` to
                have one row per ``op_x`` row, and ``op_m_idx`` columns /
                ``m_m_idx`` to line up with the ``m_x`` rows.

        Returns:
            The same ``data`` object with ``op_x`` and ``m_x`` replaced;
            every other entry is left untouched.
        """
        op_x, m_x = data['op_x'], data['m_x']
        op_m_idx = data['op_m_idx']

        # 0/1 connectivity taken from the positive entries of op_m_idx;
        # shared by both cross-relation aggregations.
        op_m_mask = (op_m_idx > 0.).to(torch.float32)

        message_op_to_op = torch.mm(data['op_op_idx'], op_x)
        # Cross-relation messages: summed neighbour features plus one
        # extra column holding the summed (unmasked) op_m_idx values.
        message_m_to_op = torch.cat(
            (
                torch.mm(op_m_mask, m_x),
                torch.sum(op_m_idx, dim=1, keepdim=True),
            ),
            dim=1,
        )
        message_op_to_m = torch.cat(
            (
                torch.mm(op_m_mask.transpose(0, 1), op_x),
                torch.sum(op_m_idx.transpose(0, 1), dim=1, keepdim=True),
            ),
            dim=1,
        )
        message_m_to_m = torch.mm(data['m_m_idx'], m_x)

        # The one-row guard is applied on both node sets, so a graph with
        # a single `m` row is handled the same way as one with a single
        # `op` row.
        data['op_x'] = (
            self._apply_mlp(self.op__to__op, message_op_to_op)
            + self._apply_mlp(self.m__to__op, message_m_to_op)
        )
        data['m_x'] = (
            self._apply_mlp(self.m__to__m, message_m_to_m)
            + self._apply_mlp(self.op__to__m, message_op_to_m)
        )
        return data

class GNN(nn.Module):
    """Stack of ``GINE_layer`` blocks with a linear head per node set.

    ``m_x`` is first projected from 3 features to the width expected by
    the first layer; every layer is followed by a ReLU on both node sets,
    and the final ``op_x`` / ``m_x`` each pass through their own linear
    layer.
    """

    # Entries of the input mapping that ``forward`` leaves as they are.
    _NON_TENSOR_KEYS = ("unfinish_op", "machine_num")

    def __init__(self, args):
        """Create the input projection, the layer stack and the heads.

        Args:
            args: Configuration object. Reads ``delete_node`` (selects a
                first-layer width of 5 instead of 6), ``GNN_num_layers``
                and ``hidden_dim`` here, and ``device`` in ``forward``.
        """
        super().__init__()
        self.args = args
        self.convs = nn.ModuleList()

        # Equality (not truthiness) is kept on purpose so that non-bool
        # flag values select the same width as before.
        first_dim = 5 if args.delete_node == True else 6  # noqa: E712
        self.m_trans_fc = Linear(3, first_dim)
        self.convs.append(GINE_layer(args, first_dim))

        for _ in range(args.GNN_num_layers - 1):
            self.convs.append(GINE_layer(args, args.hidden_dim))

        self.op_fc = Linear(args.hidden_dim, args.hidden_dim)
        self.m_fc = Linear(args.hidden_dim, args.hidden_dim)

    def forward(self, data):
        """Embed both node sets and return the updated ``data``.

        Args:
            data: Mapping of graph inputs. Every entry except
                ``"unfinish_op"`` and ``"machine_num"`` is converted to a
                float32 tensor on ``self.args.device``. ``m_x`` must have
                3 feature columns (input width of ``m_trans_fc``).
                NOTE(review): ``op_x`` presumably has to match the
                first-layer width (5 or 6) -- confirm against the caller.

        Returns:
            The same ``data`` object, modified in place, with ``op_x``
            and ``m_x`` holding the final embeddings.
        """
        device = self.args.device
        # Only existing keys are reassigned, so the mapping's size does
        # not change while it is being iterated.
        for key in data.keys():
            if key in self._NON_TENSOR_KEYS:
                continue
            # as_tensor accepts arrays, nested lists and tensors alike,
            # including tensors that already live on `device`, which the
            # legacy FloatTensor constructor rejects.
            data[key] = torch.as_tensor(
                data[key], dtype=torch.float32, device=device
            )

        data['m_x'] = self.m_trans_fc(data['m_x'])

        for conv in self.convs:
            data = conv(data)
            data['op_x'] = F.relu(data['op_x'])
            data['m_x'] = F.relu(data['m_x'])

        # Both heads are plain linear layers that act on each row
        # independently, so a one-row input needs no special handling.
        data['op_x'] = self.op_fc(data['op_x'])
        data['m_x'] = self.m_fc(data['m_x'])

        return data